import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import pandas as pd
import requests


@dataclass(frozen=True)
class EurostatQuery:
    """Describe one request against a Eurostat dataset endpoint.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. NOTE: ``params`` is a list, so it can still be mutated in
    place, and ``hash()`` on an instance raises ``TypeError`` because a list
    is unhashable.
    """

    # Dataset code; ``fetch`` appends it to the API base URL.
    dataset: str
    # Query-string filters as (name, value) pairs. A list of pairs (rather
    # than a dict) lets the same name appear several times, e.g. one "geo"
    # entry per country.
    params: list[tuple[str, str]]


def _unravel(flat: int, sizes: list[int]) -> list[int]:
    idx = []
    for size in reversed(sizes):
        idx.append(flat % size)
        flat //= size
    return list(reversed(idx))


def _codes_by_pos(dim: dict) -> list[str]:
    index = (dim.get("category", {}) or {}).get("index", {}) or {}
    codes = [None] * len(index)
    for code, pos in index.items():
        if isinstance(pos, int) and 0 <= pos < len(codes):
            codes[pos] = code
    return [c for c in codes if c is not None]


def parse_rows(js: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten a JSON-stat dataset into one dict per observed cell.

    Args:
        js: Decoded response with the ``id``, ``size``, ``dimension`` and
            ``value`` members. ``value`` may be a sparse object keyed by the
            flat cell offset or a dense array indexed by it.

    Returns:
        One row per cell. Each row maps every dimension id in ``id`` to the
        category code at the cell's position (``None`` when the code cannot
        be resolved) and carries the observation under ``"value"``. An empty
        list is returned when ``id`` or ``size`` is missing or empty.

    Raises:
        ValueError: If ``id`` and ``size`` have different lengths and at
            least one cell is present.
    """
    ids = js.get("id", [])
    sizes = js.get("size", [])
    dims = js.get("dimension", {}) or {}
    if not ids or not sizes:
        return []

    codes_map: dict[str, list] = {
        dim_id: _codes_by_pos(dims.get(dim_id) or {}) for dim_id in ids
    }

    # Number of cells in the cube. Offsets outside [0, total) cannot be
    # unravelled meaningfully: they would wrap around and be labelled with
    # the codes of a different cell. A zero-sized dimension gives total == 0,
    # which also keeps _unravel from dividing by zero.
    total = 1
    for size in sizes:
        total *= size

    raw = js.get("value", {}) or {}
    if isinstance(raw, list):
        # Dense form: the array index is the flat offset and null marks a
        # missing cell, which the sparse form simply omits.
        cells = ((pos, val) for pos, val in enumerate(raw) if val is not None)
    else:
        cells = raw.items()

    out = []
    for flat_key, val in cells:
        try:
            flat = int(flat_key)
        except (TypeError, ValueError):
            # Key is not an integer offset; skip it.
            continue
        if not 0 <= flat < total:
            continue
        coords = _unravel(flat, sizes)
        row = {}
        for dim_id, pos in zip(ids, coords, strict=True):
            codes = codes_map.get(dim_id, [])
            row[dim_id] = codes[pos] if 0 <= pos < len(codes) else None
        row["value"] = val
        out.append(row)
    return out


def fetch(
    query: EurostatQuery,
    session: requests.Session,
    *,
    max_attempts: int = 6,
    timeout: float = 90,
) -> dict[str, Any]:
    """Download one Eurostat dataset slice as decoded JSON, retrying on failure.

    Connection errors, timeouts, truncated bodies, HTTP error statuses and
    undecodable JSON are all retried. The wait between attempts starts at one
    second, doubles each time and is capped at 15 seconds.

    Args:
        query: Dataset code and query-string filters to request.
        session: Session used to send the request.
        max_attempts: Total number of attempts before giving up.
        timeout: Per-request timeout in seconds, passed to ``session.get``.

    Returns:
        The decoded JSON body of the first successful response.

    Raises:
        ValueError: If ``max_attempts`` is less than 1, or if the body of the
            final attempt is not valid JSON.
        requests.RequestException: The transport or HTTP error of the final
            attempt.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts!r}")

    base = "https://ec.europa.eu/eurostat/api/dissemination/statistics/1.0/data/"
    url = base + query.dataset
    headers = {"User-Agent": "ess-research/1.0"}

    backoff = 1.0
    for attempt in range(max_attempts):
        is_last = attempt == max_attempts - 1
        try:
            r = session.get(url, params=query.params, timeout=timeout, headers=headers)
            # Short, explicit message for these statuses; raise_for_status()
            # below raises HTTPError for every other 4xx/5xx response.
            if r.status_code in {400, 413, 414, 431, 500, 502, 503}:
                raise requests.HTTPError(f"status={r.status_code}", response=r)
            r.raise_for_status()
        except (
            requests.ConnectionError,
            requests.Timeout,
            requests.HTTPError,
            # Body cut off mid-transfer; not a ConnectionError subclass.
            requests.exceptions.ChunkedEncodingError,
        ):
            if is_last:
                raise
        else:
            try:
                return r.json()
            except ValueError:
                # Every JSON decode error requests can raise (stdlib json,
                # simplejson, or requests' own JSONDecodeError) is a
                # ValueError, so this works regardless of which JSON backend
                # or requests version is installed. Kept out of the outer
                # try so ValueError-derived URL errors are not retried.
                if is_last:
                    raise
        time.sleep(backoff)
        backoff = min(backoff * 2, 15)
    raise RuntimeError("Unexpected retry fallthrough")


def chunk(xs: list[str], n: int) -> list[list[str]]:
    """Split *xs* into consecutive sublists of at most *n* items.

    Args:
        xs: Items to split; order is preserved.
        n: Maximum length of each sublist; must be at least 1.

    Returns:
        The sublists in order. Only the last one may be shorter than *n*.
        An empty *xs* gives an empty list.

    Raises:
        ValueError: If *n* is less than 1. Without this check a negative
            *n* would produce an empty range and silently discard every
            item.
    """
    if n < 1:
        raise ValueError(f"chunk size must be a positive integer, got {n!r}")
    return [xs[i : i + n] for i in range(0, len(xs), n)]


def main() -> None:
    """Download isoc_cbs broadband coverage for the ESS country-years and write a CSV.

    Steps:
        1. Read the ESS extract and collect the country codes and survey
           years of rows whose ``regunit`` is 1 or 2.
        2. Query Eurostat ``isoc_cbs`` once per year and batch of ten
           countries, for every speed tier listed below.
        3. Flatten the responses into rows and write them, sorted, to
           ``data_external/broadband_country_year_eurostat_isoc_cbs.csv``
           under the directory one level above this file's directory.

    A header-only CSV is written when no rows come back.
    """
    root = Path(__file__).resolve().parents[1]
    out_path = root / "data_external" / "broadband_country_year_eurostat_isoc_cbs.csv"
    out_path.parent.mkdir(parents=True, exist_ok=True)

    ess = pd.read_parquet(root / "outputs" / "ess_r8_r11_min.parquet")
    ess = ess[ess["regunit"].isin([1, 2])].copy()
    # NOTE(review): Eurostat geo codes use EL and UK where ISO 3166 uses GR
    # and GB. Confirm the `cntry` values match Eurostat's code list; if they
    # do not, those countries may come back with no rows.
    countries = sorted(ess["cntry"].dropna().astype(str).unique().tolist())
    years = sorted(int(x) for x in ess["survey_year"].dropna().unique().tolist())

    # Keep years supported by isoc_cbs.
    years = [y for y in years if 2013 <= y <= 2024]

    # Speed tiers in Eurostat isoc_cbs (supply-side coverage, % households).
    # Querying all tiers keeps the pipeline robust and lets us pick the strongest shocks later.
    speeds = [
        "MBPS_GT2",
        "MBPS_GT30",
        "MBPS_GT100",
        "GBPS_GT1",
        "GBPS_GT1_UD",
    ]

    # Filters that are identical for every request, built once.
    # Fixing these dimensions keeps the payload small and stable.
    fixed_params: list[tuple[str, str]] = [("freq", "A"), ("unit", "PC_HH"), ("terrtypo", "TOTAL")]
    speed_params: list[tuple[str, str]] = [("inet_spd", spd) for spd in speeds]

    # Declared up front so the frame has these columns even when no rows are
    # fetched; otherwise dropna(subset=...) below raises KeyError.
    columns = ["cntry", "year", "inet_spd", "coverage_pc_hh", "source"]

    rows = []
    # The context manager closes the session's pooled connections on exit,
    # including when a request finally fails.
    with requests.Session() as session:
        for year in years:
            for batch in chunk(countries, 10):
                params: list[tuple[str, str]] = [("time", str(year))]
                params.extend(("geo", c) for c in batch)
                params.extend(fixed_params)
                params.extend(speed_params)

                js = fetch(EurostatQuery("isoc_cbs", params), session)
                for r in parse_rows(js):
                    time_code = r.get("time")
                    value = r.get("value")
                    rows.append(
                        {
                            "cntry": r.get("geo"),
                            "year": int(time_code) if time_code is not None else None,
                            "inet_spd": r.get("inet_spd"),
                            "coverage_pc_hh": float(value) if value is not None else None,
                            "source": "Eurostat API (isoc_cbs)",
                        }
                    )
                # Brief pause between requests.
                time.sleep(0.2)

            print(f"Fetched isoc_cbs: year={year} rows={len(rows):,}", flush=True)

    df = pd.DataFrame(rows, columns=columns)
    df = df.dropna(subset=["cntry", "year", "inet_spd"]).copy()
    df["year"] = pd.to_numeric(df["year"], errors="coerce").astype("Int64")
    df["coverage_pc_hh"] = pd.to_numeric(df["coverage_pc_hh"], errors="coerce")
    df = df.sort_values(["cntry", "year", "inet_spd"])
    df.to_csv(out_path, index=False, encoding="utf-8")
    print(f"Wrote: {out_path} rows={len(df):,} countries={df['cntry'].nunique()} years={df['year'].nunique()}")


# Run the download only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
